import argparse

import mindspore as ms
from mindformers.tools.register import MindFormerConfig
from mindformers.models import LlamaConfig, LlamaForCausalLM, LlamaTokenizer

# Interactive greedy-decoding REPL for a MindFormers LLaMA model.
# Every CLI default reproduces the value the script originally hard-coded.

DEFAULT_CONFIG = '/home/zhangsenzhen/2023Q2/mindformers/configs/llama/run_llama_13b_ziya.yaml'
DEFAULT_DEVICE_ID = 3
DEFAULT_MAX_LENGTH = 512
# Token ids passed to model.generate(); hard-coded in the original script.
# NOTE(review): presumably LLaMA's </s> (2) and pad/<unk> (0) ids — confirm against the vocab.
EOS_TOKEN_ID = 2
PAD_TOKEN_ID = 0
PROMPT = "input:"


def parse_args(argv=None):
    """Parse command-line options for the REPL.

    Args:
        argv: Argument list to parse; ``None`` means ``sys.argv[1:]``.

    Returns:
        An ``argparse.Namespace`` with ``config``, ``device_id``,
        ``max_length`` and ``ziya_template`` attributes.
    """
    parser = argparse.ArgumentParser(
        description="Interactive greedy text generation with a MindFormers LLaMA model.")
    parser.add_argument("--config", default=DEFAULT_CONFIG,
                        help="Path to the MindFormers YAML config file.")
    parser.add_argument("--device-id", type=int, default=DEFAULT_DEVICE_ID,
                        help="Ascend device id passed to ms.set_context.")
    parser.add_argument("--max-length", type=int, default=DEFAULT_MAX_LENGTH,
                        help="max_length passed to model.generate().")
    parser.add_argument("--ziya-template", action="store_true",
                        help="Wrap each query as '<human>:{query}\\n<bot>:' before tokenizing.")
    return parser.parse_args(argv)


def build_prompt(query, use_template=False):
    """Return the text that is fed to the tokenizer for *query*.

    Without the template the raw query is used unchanged, which is the
    script's original behavior. With the template the query is stripped and
    wrapped in the '<human>:' / '<bot>:' chat markers.
    """
    if use_template:
        return '<human>:' + query.strip() + '\n<bot>:'
    return query


def read_query(reader=input):
    """Prompt for one line of input.

    Args:
        reader: ``input``-compatible callable that takes the prompt string.

    Returns:
        The line that was read, or ``None`` when stdin is exhausted (EOF).
    """
    try:
        return reader(PROMPT)
    except EOFError:
        return None


def main(argv=None):
    """Load the model and tokenizer, then answer queries until EOF or Ctrl-C."""
    args = parse_args(argv)

    # mode=0 selects MindSpore graph mode (ms.GRAPH_MODE).
    ms.set_context(device_target="Ascend", device_id=args.device_id, mode=0)
    config = MindFormerConfig(args.config)
    model_config = LlamaConfig(**config.model.model_config)

    model = LlamaForCausalLM(model_config)
    tokenizer = LlamaTokenizer(config.processor.tokenizer.vocab_file)

    try:
        while True:
            query = read_query()
            if query is None:  # EOF (e.g. Ctrl-D): leave the REPL cleanly.
                print()
                break
            if not query.strip():  # Empty input would give generate() no tokens.
                continue

            inputs = build_prompt(query, args.ziya_template)
            input_ids = tokenizer(inputs, add_special_tokens=False)
            # do_sample=False with top_k=1 requests greedy decoding.
            # generate() returns a batch of sequences; take the first.
            generate_ids = model.generate(
                input_ids["input_ids"],
                max_length=args.max_length,
                do_sample=False,
                top_k=1,
                eos_token_id=EOS_TOKEN_ID,
                pad_token_id=PAD_TOKEN_ID)[0]

            print(tokenizer.decode(generate_ids))
    except KeyboardInterrupt:  # Ctrl-C ends the session without a traceback.
        print()


if __name__ == "__main__":
    main()

